# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Provides utilities for compiling and inspecting Mojo code.

This module contains functionality for compiling Mojo functions and examining
their assembly, LLVM IR, or object code output. It is particularly useful for
kernel engineers who want to inspect the low-level implementation details of
specific functions without dealing with entire files or manual invocation of
compilation tools.

Key features:
- Compile individual functions to assembly, LLVM IR or bitcode (unoptimized
  or optimized), or object code
- Get linkage names and module information
- Inspect number of captures and other function metadata
- Write compilation output to files
- Control compilation options and targets

Example:
```mojo
from compile import compile_info

fn my_func(x: Int) -> Int:
    return x

# Get assembly for the function
info = compile_info[my_func]()
print(info)
```
"""

from collections.string.string_slice import _get_kgen_string
from os import PathLike
from pathlib import Path
from sys.info import CompilationTarget, _current_target, _TargetType

from .reflection import get_linkage_name

# ===-----------------------------------------------------------------------===#
# compile_info
# ===-----------------------------------------------------------------------===#


@register_passable("trivial")
struct _Info:
    """A compiled closure implementation.

    Internal struct used to store compilation results from MLIR. It is passed
    as the `_type` of the `kgen.compile_offload` operation in `compile_info`,
    so its field order and types presumably have to match the result the
    compiler produces. NOTE(review): confirm before reordering or adding
    fields.

    Attributes:
        asm: The generated assembly/IR code as a string.
        module_name: Name of the module containing the compiled function.
        num_captures: Number of captured variables.
        capture_sizes: Pointer to the sizes of the captured variables.
    """

    var asm: __mlir_type.`!kgen.string`
    var module_name: __mlir_type.`!kgen.string`
    var num_captures: __mlir_type.index
    var capture_sizes: UnsafePointer[UInt64, ImmutOrigin.external]


@register_passable("trivial")
struct _PopulateInfo:
    """A compiled populate closure implementation.

    This struct matches the type of the compiled closure type generated by
    `compile_offload`. Used internally for closure handling: it is the result
    type of the `#kgen.compile_offload_closure` attribute from which
    `CompiledFunctionInfo.populate` is extracted.

    Attributes:
        populate: Function pointer to populate captured variables. It takes
            a single opaque `!kgen.pointer<none>` argument. Presumably this is
            the destination buffer for the captures; TODO confirm.
    """

    var populate: fn (__mlir_type.`!kgen.pointer<none>`) capturing -> None


@fieldwise_init
@register_passable("trivial")
struct CompiledFunctionInfo[
    func_type: AnyTrivialRegType,
    func: func_type,
    target: _TargetType,
](Stringable, Writable):
    """Contains compilation information and results for a function.

    Stores the assembly/IR code and the metadata produced by compiling a
    function. The metadata is the linkage name, the module name, and the
    number and sizes of captured variables. Instances are returned by
    `compile_info`.

    Parameters:
        func_type: Type of the function being compiled.
        func: The function being compiled.
        target: The target architecture to compile for.

    Attributes:
        populate: Function to populate captures. It is a compile-time value
            derived from `func` and `target`, not a stored field.
    """

    var asm: StaticString
    """Generated assembly/IR code from the compilation process."""

    var function_name: StaticString
    """Mangled name of the compiled function, used for symbol resolution."""

    var module_name: StaticString
    """Name of the module containing the compiled function."""

    var num_captures: Int
    """Number of variables captured by the function closure."""

    var capture_sizes: UnsafePointer[UInt64, ImmutOrigin.external]
    """Pointer to the sizes of the variables captured by the function closure."""

    # Obtained from the `#kgen.compile_offload_closure` attribute for
    # (target, func). The attribute's `populate` field is rebound from the
    # MLIR `!kgen.pointer<none>` argument type to `OpaquePointer`.
    comptime populate = rebind[
        fn (OpaquePointer[MutAnyOrigin]) capturing -> None
    ](
        __mlir_attr[
            `#kgen.compile_offload_closure<`,
            Self.target,
            `,`,
            Self.func,
            `> : `,
            _PopulateInfo,
        ].populate
    )
    """Function pointer to populate captured variables in the function closure.
    """

    @no_inline
    fn write_to(self, mut writer: Some[Writer]):
        """Writes the assembly/IR to a writer.

        Only `asm` is written; the other metadata fields are not included.

        Args:
            writer: Writer object to write the assembly to.
        """
        return writer.write(self.asm)

    fn __str__(self) -> String:
        """Converts the assembly/IR to a string.

        Delegates to `write_to`, so the result is exactly `asm`.

        Returns:
            The assembly/IR as a string.
        """
        return String.write(self)

    @no_inline
    fn write_text[path_like: PathLike](self, path: path_like) raises:
        """Writes the assembly/IR to a file.

        The text written is `String(self)`, i.e. the `asm` field.

        Parameters:
            path_like: Type that implements the `PathLike` interface for file
                path representation.

        Args:
            path: Path to write the file to.

        Raises:
            If file writing operations fail.
        """
        Path(path.__fspath__()).write_text(String(self))

    @no_inline
    fn __contains__(self, content: String) -> Bool:
        """Checks if content exists in the assembly/IR.

        Performs a substring search over `String(self)`. A new `String` is
        built from the output on every call.

        Args:
            content: String to search for.

        Returns:
            True if content is found, False otherwise.
        """
        return content in String(self)


# Integer ids passed as the `emission_kind` of `kgen.compile_offload`, selected
# by `_get_emission_kind_id`. NOTE(review): these values presumably must match
# the compiler's own encoding of emission kinds; confirm before changing them.
comptime _EMISSION_KIND_ASM = 0
comptime _EMISSION_KIND_LLVM = 1
comptime _EMISSION_KIND_LLVM_OPT = 2
comptime _EMISSION_KIND_OBJECT = 3
comptime _EMISSION_KIND_LLVM_BITCODE = 4
comptime _EMISSION_KIND_LLVM_OPT_BITCODE = 5


fn _get_emission_kind_id[emission_kind: StaticString]() -> Int:
    """Maps an emission-kind name to the id used by `kgen.compile_offload`.

    Parameters:
        emission_kind: One of "asm", "llvm", "llvm-bitcode", "llvm-opt",
            "llvm-opt-bitcode", or "object". Any other value fails the
            compile-time assertion below.

    Returns:
        The matching `_EMISSION_KIND_*` constant.
    """
    # The error message must list every accepted kind, including the bitcode
    # variants, so users are not told that valid values are invalid.
    __comptime_assert emission_kind in [
        "asm",
        "llvm",
        "llvm-bitcode",
        "llvm-opt",
        "llvm-opt-bitcode",
        "object",
    ], (
        "invalid emission kind '"
        + emission_kind
        + "', must be one of 'asm', 'llvm', 'llvm-bitcode', 'llvm-opt',"
        + " 'llvm-opt-bitcode', or 'object'"
    )

    @parameter
    if emission_kind == "llvm":
        return _EMISSION_KIND_LLVM
    elif emission_kind == "llvm-bitcode":
        return _EMISSION_KIND_LLVM_BITCODE
    elif emission_kind == "llvm-opt":
        return _EMISSION_KIND_LLVM_OPT
    elif emission_kind == "llvm-opt-bitcode":
        return _EMISSION_KIND_LLVM_OPT_BITCODE
    elif emission_kind == "object":
        return _EMISSION_KIND_OBJECT
    else:
        # The assertion above guarantees the only remaining value is "asm".
        return _EMISSION_KIND_ASM


@always_inline
fn compile_info[
    func_type: AnyTrivialRegType, //,
    func: func_type,
    /,
    *,
    emission_kind: StaticString = "asm",
    target: _TargetType = _current_target(),
    compile_options: StaticString = CompilationTarget[
        target
    ].default_compile_options(),
]() -> CompiledFunctionInfo[func_type, func, target]:
    """Compiles a function and returns detailed compilation information.

    This function takes a Mojo function and compiles it, providing access to the
    generated assembly code, linkage information, and other compilation
    artifacts. It can be used for inspection, debugging, and low-level
    optimization.

    Parameters:
        func_type: Type of the function to compile. Must be a trivially-copyable
            register type.
        func: The function to compile. Must match the specified func_type.
        emission_kind: The desired output format. Valid options are:
            - "asm": Assembly code (default).
            - "llvm": Unoptimized LLVM IR.
            - "llvm-bitcode": Unoptimized LLVM bitcode.
            - "llvm-opt": Optimized LLVM IR.
            - "llvm-opt-bitcode": Optimized LLVM bitcode.
            - "object": Object code.
            Any other value is rejected by a compile-time assertion.
        target: The target architecture to compile for. Defaults to current
            architecture.
        compile_options: Additional compiler flags and options as a string.
            Defaults to `CompilationTarget[target].default_compile_options()`.

    Returns:
        A `CompiledFunctionInfo` struct containing:
        - asm: The generated code in the requested format.
        - function_name: The mangled (linkage) name of `func` for `target`.
        - module_name: The name of the module containing the compiled
          function.
        - num_captures: Number of captured variables.
        - capture_sizes: Pointer to the sizes of the captured variables.

    Example:

        ```mojo
        from compile import compile_info

        fn my_func(x: Int) -> Int:
            return x

        info = compile_info[my_func]()
        print(info)  # Print assembly
        ```

    Note:
        The compilation is always performed, even if the function is not used.
        For performance-critical code, consider caching the compilation results.
    """

    # `_Info` is the struct type the compiler fills in with the compiled
    # artifact and capture metadata.
    var offload = __mlir_op.`kgen.compile_offload`[
        target_type=target,
        emission_kind = _get_emission_kind_id[emission_kind]()._mlir_value,
        emission_option = _get_kgen_string[compile_options](),
        func=func,
        _type=_Info,
    ]()

    return CompiledFunctionInfo[func_type, func, target](
        asm=StaticString(offload.asm),
        function_name=get_linkage_name[func, target=target](),
        module_name=StaticString(offload.module_name),
        num_captures=Int(mlir_value=offload.num_captures),
        capture_sizes=offload.capture_sizes,
    )
